{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# PySpark Data Wrangling"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from pyspark import SparkConf, SparkContext\n",
    "from pyspark.sql import SQLContext, SparkSession\n",
    "\n",
    "# Reuse the running SparkContext when there is one: a bare SparkContext()\n",
    "# raises ValueError if this cell is executed a second time in the same kernel.\n",
    "sc = SparkContext.getOrCreate()\n",
    "\n",
    "# Legacy SQL entry point; later cells still call sqlContext.createDataFrame.\n",
    "sqlContext = SQLContext(sc)\n",
    "\n",
    "# SparkSession is the entry point used for spark.read / spark.createDataFrame.\n",
    "# getOrCreate() returns the existing session if one is already active.\n",
    "spark = SparkSession \\\n",
    "    .builder \\\n",
    "    .appName(\"Python Spark Data Wrangling\") \\\n",
    "    .config(\"spark.some.config.option\", \"some-value\") \\\n",
    "    .getOrCreate()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+---+--------+-------------+-------------+\n",
      "| id|category|categoryIndex|  categoryVec|\n",
      "+---+--------+-------------+-------------+\n",
      "|  0|       a|          0.0|(2,[0],[1.0])|\n",
      "|  1|       b|          2.0|    (2,[],[])|\n",
      "|  2|       c|          1.0|(2,[1],[1.0])|\n",
      "|  3|       a|          0.0|(2,[0],[1.0])|\n",
      "|  4|       a|          0.0|(2,[0],[1.0])|\n",
      "|  5|       c|          1.0|(2,[1],[1.0])|\n",
      "+---+--------+-------------+-------------+\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from pyspark.ml.feature import OneHotEncoder, StringIndexer\n",
    "\n",
    "# Toy frame: an integer id plus a string category column.\n",
    "df = spark.createDataFrame(\n",
    "    [(0, \"a\"), (1, \"b\"), (2, \"c\"), (3, \"a\"), (4, \"a\"), (5, \"c\")],\n",
    "    [\"id\", \"category\"])\n",
    "\n",
    "# Step 1: fit a category -> numeric index mapping and append it as categoryIndex.\n",
    "stringIndexer = StringIndexer(inputCol=\"category\", outputCol=\"categoryIndex\")\n",
    "model = stringIndexer.fit(df)\n",
    "indexed = model.transform(df)\n",
    "\n",
    "# Step 2: encode the numeric index into the vector column categoryVec.\n",
    "encoder = OneHotEncoder(inputCol=\"categoryIndex\", outputCol=\"categoryVec\")\n",
    "encoded = encoder.transform(indexed)\n",
    "encoded.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Python equivalent of the Scala-style Map attempted here originally: a plain dict.\n",
    "# NOTE(review): presumably feature index -> number of categories (feature 1 has\n",
    "# 3 categories), as expected by the MLlib tree APIs - confirm the intended use.\n",
    "categoricalFeaturesInfo = {1: 3}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Chose 351 categorical features: 645, 69, 365, 138, 101, 479, 333, 249, 0, 555, 666, 88, 170, 115, 276, 308, 5, 449, 120, 247, 614, 677, 202, 10, 56, 533, 142, 500, 340, 670, 174, 42, 417, 24, 37, 25, 257, 389, 52, 14, 504, 110, 587, 619, 196, 559, 638, 20, 421, 46, 93, 284, 228, 448, 57, 78, 29, 475, 164, 591, 646, 253, 106, 121, 84, 480, 147, 280, 61, 221, 396, 89, 133, 116, 1, 507, 312, 74, 307, 452, 6, 248, 60, 117, 678, 529, 85, 201, 220, 366, 534, 102, 334, 28, 38, 561, 392, 70, 424, 192, 21, 137, 165, 33, 92, 229, 252, 197, 361, 65, 97, 665, 583, 285, 224, 650, 615, 9, 53, 169, 593, 141, 610, 420, 109, 256, 225, 339, 77, 193, 669, 476, 642, 637, 590, 679, 96, 393, 647, 173, 13, 41, 503, 134, 73, 105, 2, 508, 311, 558, 674, 530, 586, 618, 166, 32, 34, 148, 45, 161, 279, 64, 689, 17, 149, 584, 562, 176, 423, 191, 22, 44, 59, 118, 281, 27, 641, 71, 391, 12, 445, 54, 313, 611, 144, 49, 335, 86, 672, 172, 113, 681, 219, 419, 81, 230, 362, 451, 76, 7, 39, 649, 98, 616, 477, 367, 535, 103, 140, 621, 91, 66, 251, 668, 198, 108, 278, 223, 394, 306, 135, 563, 226, 3, 505, 80, 167, 35, 473, 675, 589, 162, 531, 680, 255, 648, 112, 617, 194, 145, 48, 557, 690, 63, 640, 18, 282, 95, 310, 50, 67, 199, 673, 16, 585, 502, 338, 643, 31, 336, 613, 11, 72, 175, 446, 612, 143, 43, 250, 231, 450, 99, 363, 556, 87, 203, 671, 688, 104, 368, 588, 40, 304, 26, 258, 390, 55, 114, 171, 139, 418, 23, 8, 75, 119, 58, 667, 478, 536, 82, 620, 447, 36, 168, 146, 30, 51, 190, 19, 422, 564, 305, 107, 4, 136, 506, 79, 195, 474, 664, 532, 94, 283, 395, 332, 528, 644, 47, 15, 163, 200, 68, 62, 277, 691, 501, 90, 111, 254, 227, 337, 122, 83, 309, 560, 639, 676, 222, 592, 364, 100\n",
      "+-----+--------------------+--------------------+\n",
      "|label|            features|             indexed|\n",
      "+-----+--------------------+--------------------+\n",
      "|  0.0|(692,[127,128,129...|(692,[127,128,129...|\n",
      "|  1.0|(692,[158,159,160...|(692,[158,159,160...|\n",
      "|  1.0|(692,[124,125,126...|(692,[124,125,126...|\n",
      "|  1.0|(692,[152,153,154...|(692,[152,153,154...|\n",
      "|  1.0|(692,[151,152,153...|(692,[151,152,153...|\n",
      "|  0.0|(692,[129,130,131...|(692,[129,130,131...|\n",
      "|  1.0|(692,[158,159,160...|(692,[158,159,160...|\n",
      "|  1.0|(692,[99,100,101,...|(692,[99,100,101,...|\n",
      "|  0.0|(692,[154,155,156...|(692,[154,155,156...|\n",
      "|  0.0|(692,[127,128,129...|(692,[127,128,129...|\n",
      "|  1.0|(692,[154,155,156...|(692,[154,155,156...|\n",
      "|  0.0|(692,[153,154,155...|(692,[153,154,155...|\n",
      "|  0.0|(692,[151,152,153...|(692,[151,152,153...|\n",
      "|  1.0|(692,[129,130,131...|(692,[129,130,131...|\n",
      "|  0.0|(692,[154,155,156...|(692,[154,155,156...|\n",
      "|  1.0|(692,[150,151,152...|(692,[150,151,152...|\n",
      "|  0.0|(692,[124,125,126...|(692,[124,125,126...|\n",
      "|  0.0|(692,[152,153,154...|(692,[152,153,154...|\n",
      "|  1.0|(692,[97,98,99,12...|(692,[97,98,99,12...|\n",
      "|  1.0|(692,[124,125,126...|(692,[124,125,126...|\n",
      "+-----+--------------------+--------------------+\n",
      "only showing top 20 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from pyspark.ml.feature import VectorIndexer\n",
    "\n",
    "data = spark.read.format(\"libsvm\").load(\"/opt/spark/data/mllib/sample_libsvm_data.txt\")\n",
    "\n",
    "# Fit an indexer over the \"features\" vector column with maxCategories=10.\n",
    "indexer = VectorIndexer(inputCol=\"features\", outputCol=\"indexed\", maxCategories=10)\n",
    "indexerModel = indexer.fit(data)\n",
    "\n",
    "# Report how many features were chosen as categorical, and which ones.\n",
    "categoricalFeatures = indexerModel.categoryMaps\n",
    "chosenFeatureIds = \", \".join(str(k) for k in categoricalFeatures.keys())\n",
    "print(\"Chose %d categorical features: %s\" %\n",
    "      (len(categoricalFeatures), chosenFeatureIds))\n",
    "\n",
    "# Append the \"indexed\" column, in which categorical values are replaced by indices.\n",
    "indexedData = indexerModel.transform(data)\n",
    "indexedData.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+----------+-----+--------------------+\n",
      "|prediction|label|            features|\n",
      "+----------+-----+--------------------+\n",
      "|      0.05|  0.0|(692,[100,101,102...|\n",
      "|       0.0|  0.0|(692,[121,122,123...|\n",
      "|       0.0|  0.0|(692,[122,123,148...|\n",
      "|       0.0|  0.0|(692,[123,124,125...|\n",
      "|       0.0|  0.0|(692,[124,125,126...|\n",
      "+----------+-----+--------------------+\n",
      "only showing top 5 rows\n",
      "\n",
      "Root Mean Squared Error (RMSE) on test data = 0.149241\n",
      "RandomForestRegressionModel (uid=rfr_55b97471263d) with 20 trees\n"
     ]
    }
   ],
   "source": [
    "from pyspark.ml import Pipeline\n",
    "from pyspark.ml.regression import RandomForestRegressor\n",
    "from pyspark.ml.feature import VectorIndexer\n",
    "from pyspark.ml.evaluation import RegressionEvaluator\n",
    "\n",
    "# Fixed seed so the train/test split and the forest are repeatable across runs;\n",
    "# without it the reported RMSE changes on every execution.\n",
    "SEED = 42\n",
    "\n",
    "# Load and parse the data file, converting it to a DataFrame.\n",
    "data = spark.read.format(\"libsvm\").load(\"/opt/spark/data/mllib/sample_libsvm_data.txt\")\n",
    "\n",
    "# Automatically identify categorical features, and index them.\n",
    "# Set maxCategories so features with > 4 distinct values are treated as continuous.\n",
    "featureIndexer =\\\n",
    "    VectorIndexer(inputCol=\"features\", outputCol=\"indexedFeatures\", maxCategories=4).fit(data)\n",
    "\n",
    "# Split the data into training and test sets (30% held out for testing)\n",
    "(trainingData, testData) = data.randomSplit([0.7, 0.3], seed=SEED)\n",
    "\n",
    "# Train a RandomForest model.\n",
    "rf = RandomForestRegressor(featuresCol=\"indexedFeatures\", seed=SEED)\n",
    "\n",
    "# Chain indexer and forest in a Pipeline\n",
    "pipeline = Pipeline(stages=[featureIndexer, rf])\n",
    "\n",
    "# Train model.  This also runs the indexer.\n",
    "model = pipeline.fit(trainingData)\n",
    "\n",
    "# Make predictions.\n",
    "predictions = model.transform(testData)\n",
    "\n",
    "# Select example rows to display.\n",
    "predictions.select(\"prediction\", \"label\", \"features\").show(5)\n",
    "\n",
    "# Select (prediction, true label) and compute test error\n",
    "evaluator = RegressionEvaluator(\n",
    "    labelCol=\"label\", predictionCol=\"prediction\", metricName=\"rmse\")\n",
    "rmse = evaluator.evaluate(predictions)\n",
    "print(\"Root Mean Squared Error (RMSE) on test data = %g\" % rmse)\n",
    "\n",
    "# The forest is the second pipeline stage (index 1, after the feature indexer).\n",
    "rfModel = model.stages[1]\n",
    "print(rfModel)  # summary only"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+-----+--------------------+\n",
      "|label|            features|\n",
      "+-----+--------------------+\n",
      "|  0.0|(692,[95,96,97,12...|\n",
      "|  0.0|(692,[98,99,100,1...|\n",
      "|  0.0|(692,[122,123,124...|\n",
      "|  0.0|(692,[123,124,125...|\n",
      "|  0.0|(692,[123,124,125...|\n",
      "|  0.0|(692,[124,125,126...|\n",
      "|  0.0|(692,[124,125,126...|\n",
      "|  0.0|(692,[124,125,126...|\n",
      "|  0.0|(692,[124,125,126...|\n",
      "|  0.0|(692,[126,127,128...|\n",
      "|  0.0|(692,[126,127,128...|\n",
      "|  0.0|(692,[126,127,128...|\n",
      "|  0.0|(692,[126,127,128...|\n",
      "|  0.0|(692,[126,127,128...|\n",
      "|  0.0|(692,[126,127,128...|\n",
      "|  0.0|(692,[127,128,129...|\n",
      "|  0.0|(692,[127,128,129...|\n",
      "|  0.0|(692,[128,129,130...|\n",
      "|  0.0|(692,[129,130,131...|\n",
      "|  0.0|(692,[151,152,153...|\n",
      "+-----+--------------------+\n",
      "only showing top 20 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Preview the training split produced by the previous cell (first 20 rows).\n",
    "trainingData.show(n=20)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+----------+------------+--------------------+\n",
      "|prediction|indexedLabel|            features|\n",
      "+----------+------------+--------------------+\n",
      "|       1.0|         1.0|(692,[124,125,126...|\n",
      "|       1.0|         1.0|(692,[124,125,126...|\n",
      "|       1.0|         1.0|(692,[126,127,128...|\n",
      "|       1.0|         1.0|(692,[126,127,128...|\n",
      "|       1.0|         1.0|(692,[126,127,128...|\n",
      "+----------+------------+--------------------+\n",
      "only showing top 5 rows\n",
      "\n",
      "Test Error = 0\n",
      "RandomForestClassificationModel (uid=rfc_d0251ab2d3e1) with 10 trees\n"
     ]
    }
   ],
   "source": [
    "from pyspark.ml import Pipeline\n",
    "from pyspark.ml.classification import RandomForestClassifier\n",
    "from pyspark.ml.feature import StringIndexer, VectorIndexer\n",
    "from pyspark.ml.evaluation import MulticlassClassificationEvaluator\n",
    "\n",
    "# Fixed seed so the train/test split and the forest are repeatable across runs;\n",
    "# without it the reported test error changes on every execution.\n",
    "SEED = 42\n",
    "\n",
    "# Load and parse the data file, converting it to a DataFrame.\n",
    "data = spark.read.format(\"libsvm\").load(\"/opt/spark/data/mllib/sample_libsvm_data.txt\")\n",
    "\n",
    "# Index labels, adding metadata to the label column.\n",
    "# Fit on whole dataset to include all labels in index.\n",
    "labelIndexer = StringIndexer(inputCol=\"label\", outputCol=\"indexedLabel\").fit(data)\n",
    "# Automatically identify categorical features, and index them.\n",
    "# Set maxCategories so features with > 4 distinct values are treated as continuous.\n",
    "featureIndexer =\\\n",
    "    VectorIndexer(inputCol=\"features\", outputCol=\"indexedFeatures\", maxCategories=4).fit(data)\n",
    "\n",
    "# Split the data into training and test sets (30% held out for testing)\n",
    "(trainingData, testData) = data.randomSplit([0.7, 0.3], seed=SEED)\n",
    "\n",
    "# Train a RandomForest model.\n",
    "rf = RandomForestClassifier(labelCol=\"indexedLabel\", featuresCol=\"indexedFeatures\",\n",
    "                            numTrees=10, seed=SEED)\n",
    "\n",
    "# Chain indexers and forest in a Pipeline\n",
    "pipeline = Pipeline(stages=[labelIndexer, featureIndexer, rf])\n",
    "\n",
    "# Train model.  This also runs the indexers.\n",
    "model = pipeline.fit(trainingData)\n",
    "\n",
    "# Make predictions.\n",
    "predictions = model.transform(testData)\n",
    "\n",
    "# Select example rows to display.\n",
    "predictions.select(\"prediction\", \"indexedLabel\", \"features\").show(5)\n",
    "\n",
    "# Select (prediction, true label) and compute test error\n",
    "evaluator = MulticlassClassificationEvaluator(\n",
    "    labelCol=\"indexedLabel\", predictionCol=\"prediction\", metricName=\"accuracy\")\n",
    "accuracy = evaluator.evaluate(predictions)\n",
    "print(\"Test Error = %g\" % (1.0 - accuracy))\n",
    "\n",
    "# The forest is the third pipeline stage (index 2, after the two indexers).\n",
    "rfModel = model.stages[2]\n",
    "print(rfModel)  # summary only"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+---+--------+-------------+\n",
      "| id|category|categoryIndex|\n",
      "+---+--------+-------------+\n",
      "|  0|       a|          0.0|\n",
      "|  1|       b|          2.0|\n",
      "|  2|       c|          1.0|\n",
      "|  3|       a|          0.0|\n",
      "|  4|       a|          0.0|\n",
      "|  5|       c|          1.0|\n",
      "+---+--------+-------------+\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from pyspark.ml.feature import StringIndexer\n",
    "\n",
    "# Six (id, category) rows built through the SQLContext entry point.\n",
    "df = sqlContext.createDataFrame([\n",
    "    (0, \"a\"),\n",
    "    (1, \"b\"),\n",
    "    (2, \"c\"),\n",
    "    (3, \"a\"),\n",
    "    (4, \"a\"),\n",
    "    (5, \"c\")\n",
    "], [\"id\", \"category\"])\n",
    "\n",
    "# Fit the category -> index mapping, then append it as categoryIndex.\n",
    "indexer = StringIndexer(inputCol=\"category\", outputCol=\"categoryIndex\")\n",
    "categoryIndexerModel = indexer.fit(df)\n",
    "indexed = categoryIndexerModel.transform(df)\n",
    "indexed.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "from pyspark.ml.feature import OneHotEncoder, StringIndexer\n",
    "\n",
    "df = sqlContext.createDataFrame(\n",
    "    [(0, \"a\"), (1, \"b\"), (2, \"c\"), (3, \"a\"), (4, \"a\"), (5, \"c\")],\n",
    "    [\"id\", \"category\"])\n",
    "\n",
    "# Index the string categories; the fitted indexer is kept in `model`.\n",
    "stringIndexer = StringIndexer(inputCol=\"category\", outputCol=\"categoryIndex\")\n",
    "model = stringIndexer.fit(df)\n",
    "indexed = model.transform(df)\n",
    "\n",
    "# The one-hot step is left disabled in this cell:\n",
    "#   encoder = OneHotEncoder(includeFirst=False, inputCol=\"categoryIndex\", outputCol=\"categoryVec\")\n",
    "#   encoded = encoder.transform(indexed)\n",
    "# NOTE(review): includeFirst is presumably not accepted by the installed OneHotEncoder - confirm."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+---+--------+-------------+\n",
      "| id|category|categoryIndex|\n",
      "+---+--------+-------------+\n",
      "|  0|       a|          0.0|\n",
      "|  1|       b|          2.0|\n",
      "|  2|       c|          1.0|\n",
      "|  3|       a|          0.0|\n",
      "|  4|       a|          0.0|\n",
      "|  5|       c|          1.0|\n",
      "+---+--------+-------------+\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Display the frame with the categoryIndex column appended (first 20 rows).\n",
    "indexed.show(n=20)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+--------------------+-----+\n",
      "|            features|label|\n",
      "+--------------------+-----+\n",
      "|[7.4,0.7,0.0,1.9,...|    5|\n",
      "|[7.8,0.88,0.0,2.6...|    5|\n",
      "|[7.8,0.76,0.04,2....|    5|\n",
      "|[11.2,0.28,0.56,1...|    6|\n",
      "|[7.4,0.7,0.0,1.9,...|    5|\n",
      "|[7.4,0.66,0.0,1.8...|    5|\n",
      "|[7.9,0.6,0.06,1.6...|    5|\n",
      "|[7.3,0.65,0.0,1.2...|    7|\n",
      "|[7.8,0.58,0.02,2....|    7|\n",
      "|[7.5,0.5,0.36,6.1...|    5|\n",
      "|[6.7,0.58,0.08,1....|    5|\n",
      "|[7.5,0.5,0.36,6.1...|    5|\n",
      "|[5.6,0.615,0.0,1....|    5|\n",
      "|[7.8,0.61,0.29,1....|    5|\n",
      "|[8.9,0.62,0.18,3....|    5|\n",
      "|[8.9,0.62,0.19,3....|    5|\n",
      "|[8.5,0.28,0.56,1....|    7|\n",
      "|[8.1,0.56,0.28,1....|    5|\n",
      "|[7.4,0.59,0.08,4....|    4|\n",
      "|[7.9,0.32,0.51,1....|    6|\n",
      "+--------------------+-----+\n",
      "only showing top 20 rows\n",
      "\n",
      "+--------------+-----+--------------------+\n",
      "|predictedLabel|label|            features|\n",
      "+--------------+-----+--------------------+\n",
      "|             6|    4|[4.6,0.52,0.15,2....|\n",
      "|             6|    6|[5.0,0.38,0.01,1....|\n",
      "|             6|    6|[5.0,0.74,0.0,1.2...|\n",
      "|             6|    5|[5.0,1.04,0.24,1....|\n",
      "|             6|    6|[5.1,0.47,0.02,1....|\n",
      "+--------------+-----+--------------------+\n",
      "only showing top 5 rows\n",
      "\n",
      "Test Error = 0.412475\n",
      "RandomForestClassificationModel (uid=rfc_d49940e569c0) with 10 trees\n"
     ]
    }
   ],
   "source": [
    "from pyspark.ml import Pipeline\n",
    "from pyspark.ml.classification import RandomForestClassifier\n",
    "from pyspark.ml.evaluation import MulticlassClassificationEvaluator\n",
    "from pyspark.ml.feature import IndexToString, StringIndexer, VectorIndexer\n",
    "from pyspark.ml.linalg import Vectors\n",
    "from pyspark.sql import Row\n",
    "\n",
    "# Fixed seed so the train/test split and the forest are repeatable across runs.\n",
    "SEED = 42\n",
    "\n",
    "# Predictor columns of the wine CSV, each listed exactly once.\n",
    "# (\"residual sugar\" used to be fed into the feature vector twice.)\n",
    "FEATURE_COLUMNS = [\n",
    "    \"fixed acidity\",\n",
    "    \"volatile acidity\",\n",
    "    \"citric acid\",\n",
    "    \"residual sugar\",\n",
    "    \"chlorides\",\n",
    "    \"free sulfur dioxide\",\n",
    "    \"total sulfur dioxide\",\n",
    "    \"density\",\n",
    "    \"pH\",\n",
    "    \"sulphates\",\n",
    "    \"alcohol\",\n",
    "]\n",
    "\n",
    "# Load the wine data: the first line is a header and column types are inferred.\n",
    "df = spark.read.format('com.databricks.spark.csv') \\\n",
    "    .options(header='true', inferschema='true') \\\n",
    "    .load(\"./data/WineData.csv\")\n",
    "\n",
    "\n",
    "def transData(row):\n",
    "    \"\"\"Convert one CSV row into a Row(label, features) record.\n",
    "\n",
    "    label is the row's \"quality\" value; features is a dense vector holding\n",
    "    the FEATURE_COLUMNS values in the order they are listed.\n",
    "    \"\"\"\n",
    "    return Row(label=row[\"quality\"],\n",
    "               features=Vectors.dense([row[col] for col in FEATURE_COLUMNS]))\n",
    "\n",
    "\n",
    "data = df.rdd.map(transData).toDF()\n",
    "data.show()\n",
    "\n",
    "# Index labels, adding metadata to the label column.\n",
    "# Fit on whole dataset to include all labels in index.\n",
    "labelIndexer = StringIndexer(inputCol=\"label\", outputCol=\"indexedLabel\").fit(data)\n",
    "\n",
    "# Automatically identify categorical features, and index them.\n",
    "# Set maxCategories so features with > 4 distinct values are treated as continuous.\n",
    "featureIndexer =\\\n",
    "    VectorIndexer(inputCol=\"features\", outputCol=\"indexedFeatures\", maxCategories=4).fit(data)\n",
    "\n",
    "# Split the data into training and test sets (30% held out for testing)\n",
    "(trainingData, testData) = data.randomSplit([0.7, 0.3], seed=SEED)\n",
    "\n",
    "# Train a RandomForest model.\n",
    "rf = RandomForestClassifier(labelCol=\"indexedLabel\", featuresCol=\"indexedFeatures\",\n",
    "                            numTrees=10, seed=SEED)\n",
    "\n",
    "# Convert indexed labels back to original labels.\n",
    "labelConverter = IndexToString(inputCol=\"prediction\", outputCol=\"predictedLabel\",\n",
    "                               labels=labelIndexer.labels)\n",
    "\n",
    "# Chain indexers and forest in a Pipeline\n",
    "pipeline = Pipeline(stages=[labelIndexer, featureIndexer, rf, labelConverter])\n",
    "\n",
    "# Train model.  This also runs the indexers.\n",
    "model = pipeline.fit(trainingData)\n",
    "\n",
    "# Make predictions.\n",
    "predictions = model.transform(testData)\n",
    "\n",
    "# Select example rows to display.\n",
    "predictions.select(\"features\",\"label\",\"predictedLabel\").show(5)\n",
    "\n",
    "# Select (prediction, true label) and compute test error\n",
    "evaluator = MulticlassClassificationEvaluator(\n",
    "    labelCol=\"indexedLabel\", predictionCol=\"prediction\", metricName=\"accuracy\")\n",
    "accuracy = evaluator.evaluate(predictions)\n",
    "print(\"Test Error = %g\" % (1.0 - accuracy))\n",
    "\n",
    "# The forest is the third pipeline stage (index 2, after the two indexers).\n",
    "rfModel = model.stages[2]\n",
    "print(rfModel)  # summary only"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
